# ==============================================================================
# Copyright 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import logging
import random

import numpy as np

from daal4py.sklearn import patch_sklearn

patch_sklearn()

import pathlib
import sys

from sklearn.datasets import load_diabetes, load_iris, make_regression
from sklearn.metrics import pairwise_distances, roc_auc_score

sys.path.append(str(pathlib.Path(__file__).parent.parent.absolute()))
from _models_info import MODELS_INFO, TYPES


def get_class_name(x):
    return x.__class__.__name__


def generate_dataset(name, dtype, model_name):
    if model_name == "LinearRegression":
        X, y = make_regression(n_samples=1000, n_features=5)
    elif name in ["blobs", "classifier"]:
        X, y = load_iris(return_X_y=True)
    elif name == "regression":
        X, y = load_diabetes(return_X_y=True)
    else:
        raise ValueError("Unknown dataset type")
    X = np.array(X, dtype=dtype)
    y = np.array(y, dtype=dtype)
    return (X, y)


def run_patch(model_info, dtype):
    print(get_class_name(model_info["model"]), dtype.__name__)
    X, y = generate_dataset(
        model_info["dataset"], dtype, get_class_name(model_info["model"])
    )
    model = model_info["model"]
    model.fit(X, y)
    logging.info("fit")
    for i in model_info["methods"]:
        if i == "predict":
            model.predict(X)
        elif i == "predict_proba":
            model.predict_proba(X)
        elif i == "predict_log_proba":
            model.predict_log_proba(X)
        elif i == "decision_function":
            model.decision_function(X)
        elif i == "fit_predict":
            model.fit_predict(X)
        elif i == "transform":
            model.transform(X)
        elif i == "fit_transform":
            model.fit_transform(X)
        elif i == "kneighbors":
            model.kneighbors(X)
        elif i == "score":
            model.score(X, y)
        else:
            raise ValueError(i + " is wrong method")
        logging.info(i)


def run_algorithms():
    for info in MODELS_INFO:
        for t in TYPES:
            model_name = get_class_name(info["model"])
            if model_name in ["Ridge", "LinearRegression"] and t.__name__ == "uint32":
                continue
            run_patch(info, t)


def run_utils():
    # pairwise_distances
    for metric in ["cosine", "correlation"]:
        for t in TYPES:
            X = np.random.rand(1000)
            X = np.array(X, dtype=t)
            print("pairwise_distances", t.__name__)
            _ = pairwise_distances(X.reshape(1, -1), metric=metric)
            logging.info("pairwise_distances")
    # roc_auc_score
    for t in [np.float32, np.float64]:
        a = [random.randint(0, 1) for i in range(1000)]
        b = [random.randint(0, 1) for i in range(1000)]
        a = np.array(a, dtype=t)
        b = np.array(b, dtype=t)
        print("roc_auc_score", t.__name__)
        _ = roc_auc_score(a, b)
        logging.info("roc_auc_score")


if __name__ == "__main__":
    run_algorithms()
    run_utils()
